# -*- coding: utf-8 -*-
"""Knowledge_Knob_embeddings.ipynb
"""

#Mount google drive
#(Colab-only: the google.colab package is provided by the Colab runtime)
from google.colab import drive

#mount google drive
drive.mount("/content/gdrive")

#install library required to read knowledge graph
#NOTE: `!pip` is an IPython/Colab shell escape, not Python syntax, so this
#file only runs as notebook cells, not as a plain .py script
!pip install rdflib

#read knowledge graph from path in drive location
from rdflib import Graph as RDFGraph

#path to knowledge graph (relative path; presumably resolved from /content,
#where the Drive mount point "gdrive" lives in Colab -- TODO confirm)
path = "gdrive/MyDrive/AGNNs/"

#read and parse knowledge graph file in ttl (Turtle) format into an rdflib Graph
kg_file = path+"small_copy.ttl"
kg = RDFGraph()
kg.parse(kg_file, format='turtle')

#import tqdm library to track for loop progress
from tqdm import tqdm

#set a test flag for debugging; it gates the exploratory `if TEST:` cells below
TEST = True

#get all questions from a knowledge graph kg

def get_questions(kg):
  """Return ``{question_iri: question_string}`` for every question in ``kg``.

  A question is the object of any triple whose predicate contains
  ``'hasQuestion'``. Its text is the object of a triple whose subject is that
  question and whose predicate contains ``'hasStringValue'``. If a question has
  several string values, the last one in ``kg`` iteration order wins.
  Questions with no string value are left out.

  Args:
    kg: anything that iterates as ``(s, p, o)`` triples (e.g. an rdflib Graph).

  Returns:
    dict mapping question term -> question string term.
  """
  kg_questions = {} #initialize place holder

  print ('processing kg triples ... ') #print statement for clarity in output

  #one pass to remember each subject's (last) string value, instead of
  #rescanning the whole graph once per question (which was O(n^2))
  string_values = {}
  for s2,p2,o2 in kg:
    if 'hasStringValue' in p2:
      string_values[s2] = o2

  for s,p,o in tqdm(kg): #use the hasQuestion relation to get question
    if 'hasQuestion' in p and o in string_values: #the object o is question
      kg_questions[o] = string_values[o]
  return kg_questions

#execute function to get all questions and store it in kg_questions variable
#(dict: question term -> question string term)
kg_questions = get_questions(kg)

#get all correct choices for the questions
def get_correct_choices(kg,kg_questions):
  """Return ``{question: [program_string, ...]}`` for each key of ``kg_questions``.

  For each question:
    * its choices are the objects of its ``'hasChoices'`` triples;
    * a choice is correct when one of its ``'hasStringValue'`` objects
      contains the substring ``'correct'``;
    * the result lists ``str()`` of every ``'hasProgram'`` object of the
      correct choices, in ``kg`` iteration order.
  Questions with no correct choice map to ``[]``.

  Args:
    kg: anything that iterates as ``(s, p, o)`` triples (e.g. an rdflib Graph).
    kg_questions: iterable of question terms (e.g. the dict from get_questions).
  """
  #single pass: subject -> [(predicate, object), ...] in kg order. This replaces
  #three full-graph rescans per question/choice with dictionary lookups.
  by_subject = {}
  for s,p,o in kg:
    by_subject.setdefault(s, []).append((p,o))

  correct_choices= dict() #place holder for all question's correct choices

  print ('processing all questions ... ')
  for question in tqdm(kg_questions):
    #get answers to the question
    choices = [o for p,o in by_subject.get(question, ()) if 'hasChoices' in p]
    #get correct choices; NOTE(review): substring test, so a value such as
    #'incorrect' would also match -- confirm which values the KG uses
    question_correct_choices = [choice for choice in choices
                                for p,o in by_subject.get(choice, ())
                                if 'hasStringValue' in p and 'correct' in o]
    #get correct choice programs and add to dict
    correct_choices[question] = [str(o) for choice in question_correct_choices
                                 for p,o in by_subject.get(choice, ())
                                 if 'hasProgram' in p]

  return correct_choices #return all question's correct choices

#execute function to get all correct choices for all questions from the kg
#(dict: question term -> list of program strings of its correct choices)
kg_answers = get_correct_choices(kg,kg_questions)

#get knowledge relevant to the question from kg
def get_knowledge(question,kg):
  """Return the triples of ``kg`` that describe the video ``question`` belongs to.

  The video id is the text of ``question`` before the first ``'__'``, after
  its last ``'#'``. A triple ``(s, p, o)`` is kept when ``s`` contains the video
  id and one of ``'Scene'``, ``'Observation'``, ``'Point'``. It is dropped when
  ``p`` mentions a string value, program, question or choices, or when ``p``
  contains ``'type'`` and ``o`` contains ``'owl'`` or ``'trafficmonitoring'``.

  Returns:
    list of ``(s, p, o)`` tuples in ``kg`` iteration order.
  """
  vid = question.split('__')[0].split('#')[-1]
  ignored_predicates = ('hasStringValue', 'hasProgram', 'hasQuestion', 'hasChoices')
  ignored_type_objects = ('owl', 'trafficmonitoring') #owl types, parent ontology types
  video_markers = ('Scene', 'Observation', 'Point')

  knowledge = []
  for subj, pred, obj in kg:
    if any(name in pred for name in ignored_predicates):
      continue
    if 'type' in pred and any(name in obj for name in ignored_type_objects):
      continue
    if vid in subj and any(marker in subj for marker in video_markers):
      knowledge.append((subj, pred, obj))
  return knowledge

#test if get_knowledge function works with random question
if TEST: #check if TEST FLAG is set

  #get random question and call function
  from random import choice
  question = choice(list(kg_questions.keys()))
  relevant_knowledge = get_knowledge(question,kg)
  print (relevant_knowledge[:2]) #print a small part of it (first two triples)

#write function to format KG for use by embedding library
#i.e., use list of lists format [[s,p,o], ... , ]

def format_kg(rkg):
  """Convert triples to the list-of-lists-of-strings form used for embedding.

  Args:
    rkg: iterable of ``(s, p, o)`` triples (e.g. the output of get_knowledge).

  Returns:
    ``[[str(s), str(p), str(o)], ...]`` in the order of ``rkg``.
  """
  #iterate the argument: previously this read the global `kg`, so it always
  #returned the whole graph (or raised NameError when no global existed)
  triple_list = [[str(item) for item in triple] for triple in rkg]
  return triple_list

#test if formatting function works
if TEST: #check if test flag is set

  #get random question and call function
  from random import choice
  question = choice(list(kg_questions.keys()))
  relevant_knowledge = get_knowledge(question,kg)
  formatted_knowledge = format_kg(relevant_knowledge)
  print (formatted_knowledge[:2]) #print a small part of it (first two triples)

#install library for knowledge graph embeddings
#NOTE: IPython/Colab shell escape, not Python syntax
!pip install ampligraph

#define get_embeddings function as wrapper around the library KGE method, KGE = Knowledge Graph Embedding
import numpy as np
from ampligraph.latent_features import ScoringBasedEmbeddingModel

def get_embeddings(rkg,
                   method = 'TransE',
                   k = 5,
                   epochs = 5):
  """Fit an AmpliGraph ``ScoringBasedEmbeddingModel`` on ``rkg`` and return it.

  Args:
    rkg: triples as a list of lists ``[[s, p, o], ...]`` (see format_kg);
      converted with ``np.array`` before fitting.
    method: scoring function name, passed as ``scoring_type``.
    k: embedding size, passed as ``k``.
    epochs: number of training epochs passed to ``fit``.

  Returns:
    The fitted model (compiled with the 'adam' optimizer and 'nll' loss, eta=1).
  """
  training_triples = np.array(rkg)
  model_config = {'k': k, 'eta': 1, 'scoring_type': method}
  kge_model = ScoringBasedEmbeddingModel(**model_config)
  kge_model.compile(optimizer='adam', loss='nll')
  kge_model.fit(training_triples, epochs=epochs)
  return kge_model

#test if get_embedding function works
#embedding_model.get_embeddings([entity],embedding_type='e'), to get entity embeddings, use r for relationships 
if TEST: #check if test flag is set

  #get random question and call function(s)
  from random import choice
  question = choice(list(kg_questions.keys()))
  relevant_knowledge = get_knowledge(question,kg)
  formatted_knowledge = format_kg(relevant_knowledge)
  print (len(formatted_knowledge))
  #smoke test only: fits on the first two triples for a single epoch
  embedding_model = get_embeddings(formatted_knowledge[:2],k = 200,epochs = 1)

#keep the sampled question for the adjacency-matrix cells below
if TEST:
  test_question = question

import numpy as np
import networkx as nx

def get_question_matrix(question,kg):
  """Build the transitive-closure adjacency matrix for ``question``'s knowledge.

  Each triple returned by ``get_knowledge(question, kg)`` gets four fresh node
  ids (subject, predicate, object, inverse predicate), wired as the directed
  cycle s -> p -> o -> p_inv -> s. The matrix has 1.0 at every edge of the
  NetworkX transitive closure of those edges. Node ids are allocated per
  triple, so the same entity in two triples gets two unrelated nodes.

  Returns:
    (adj_M, n_nodes): a dense ``(n_nodes, n_nodes)`` float NumPy array and
    ``n_nodes == 4 * len(triples)``.
  """
  rkg = get_knowledge(question,kg)

  edges = {} #insertion-ordered set of directed edges (the values are unused)
  er_counter = 0 #next free node id

  #only the number of triples matters: node labels were never used downstream
  for _ in tqdm(rkg):
    #forward connections: s -> p -> o
    edges[(er_counter,er_counter+1)] = 1
    edges[(er_counter+1,er_counter+2)] = 1
    #inverse connections: o -> p_inv -> s
    edges[(er_counter+3,er_counter)] = 1
    edges[(er_counter+2,er_counter+3)] = 1
    er_counter += 4

  DG = nx.DiGraph(list(edges.keys()))
  closure_edges = list(nx.transitive_closure(DG, reflexive=False).edges()) #inferred relationships

  adj_M = np.zeros((er_counter,er_counter))

  print ('constructing adjacency matrix with inferred edges ... ')
  if closure_edges: #vectorized fill instead of one Python assignment per edge
    rows, cols = zip(*closure_edges)
    adj_M[list(rows), list(cols)] = 1.0

  return adj_M, er_counter

#build the adjacency matrix for the question sampled earlier
if TEST:
  adjacency_matrix, n_nodes = get_question_matrix(test_question,kg)

if TEST:
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  emb_size = 200
  max_nodes = 100000 #embedding table size; n_nodes must not exceed this for the lookup below
  #NOTE(review): recomputes the matrix from the previous cell and overwrites it
  adjacency_matrix, n_nodes = get_question_matrix(test_question,kg); print ('no. of nodes', n_nodes)
  embedding_matrix = nn.Embedding(max_nodes,emb_size)

  #float32 tensor sharing memory with the NumPy array (from_numpy), then cast with .float()
  adjacency_tensor = torch.from_numpy(adjacency_matrix).float(); print (adjacency_tensor.size())

#get cuda device and check no. of GPUs
import torch
device = ('cuda' if torch.cuda.is_available() else 'cpu')
if device == 'cuda':  
  print ('no. of GPUs: ',torch.cuda.device_count())
else:
  print ('No GPU found ... ')

if TEST:
  #embedding rows for node ids 0..n_nodes-1 -> shape (n_nodes, emb_size)
  node_tensors = embedding_matrix(torch.arange(n_nodes)); print (node_tensors.size())
  #Tensor.to() is not in-place: it returns the moved tensor, so re-bind the
  #results (previously they were discarded and both tensors stayed on the CPU)
  adjacency_tensor = adjacency_tensor.to(device); node_tensors = node_tensors.to(device)
  #(n_nodes, n_nodes) @ (n_nodes, emb_size) -> adjacency-weighted node vectors
  adjacency_tensor = adjacency_tensor @ node_tensors

"""#Main program starts here"""

from random import sample, shuffle, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from rdflib import Graph as RDFGraph
import numpy as np
import networkx as nx
import gc

#define knowledge graph utilities class
class CLEVRER_KG_Utils(object):
  """Static helpers for reading questions, answers and knowledge from a CLEVRER KG.

  Apart from ``parseKG_from_file`` (which returns an rdflib Graph), every
  method only iterates ``kg`` as ``(s, p, o)`` triples. Any iterable of
  3-tuples therefore works. Predicates are matched by substring
  (``'hasQuestion' in p``), not by exact IRI.
  """

  @staticmethod
  def parseKG_from_file(kg_file): #parse CLEVRER KG in turtle format using rdflib and return
    """Parse the Turtle file ``kg_file`` into a new rdflib Graph and return it."""
    kg = RDFGraph()
    kg.parse(kg_file, format='turtle')
    return kg

  @staticmethod
  def _index_by_subject(kg):
    """Return ``{subject: [(predicate, object), ...]}`` from a single pass over ``kg``.

    Each list keeps ``kg`` iteration order. Scanning it therefore yields the
    same matches, in the same order, as rescanning the whole graph for that
    subject.
    """
    index = {}
    for s,p,o in kg:
      index.setdefault(s, []).append((p,o))
    return index

  @staticmethod
  def get_questions(kg): #get all questions from the CLEVRER KG
    """Return ``{question_iri: question_string}`` for every question in ``kg``.

    A question is the object of a triple whose predicate contains
    ``'hasQuestion'``. Its text is the object of a ``'hasStringValue'`` triple
    on that question; when there are several, the last one in ``kg`` order
    wins. Questions without a string value are omitted.
    """
    kg_questions = {} #initialize place holder

    print ('processing kg triples ... ') #print statement for clarity in output

    #index once instead of rescanning every triple for every question (O(n^2))
    by_subject = CLEVRER_KG_Utils._index_by_subject(kg)

    for s,p,o in tqdm(kg): #use the hasQuestion relation to get question
      if 'hasQuestion' in p: #the object o is question
        for p2,o2 in by_subject.get(o, ()): #question string via hasStringValue
          if 'hasStringValue' in p2:
            kg_questions[o] = o2
    #loop variables are not `del`-ed: they are freed on return anyway, and the
    #old `del` raised NameError whenever kg had no (matching) triples
    gc.collect()
    return kg_questions

  @staticmethod
  def get_correct_choices(kg,kg_questions): #get all correct choices for the questions
    """Return ``{question: [program_string, ...]}`` for each key of ``kg_questions``.

    For each question:
      * its choices are the objects of its ``'hasChoices'`` triples;
      * a choice is correct when one of its ``'hasStringValue'`` objects
        contains ``'correct'``;
      * the result lists ``str()`` of every ``'hasProgram'`` object of the
        correct choices.
    Questions with no correct choice map to ``[]``. An empty ``kg_questions``
    gives ``{}``.
    """
    by_subject = CLEVRER_KG_Utils._index_by_subject(kg) #(p, o) pairs per subject
    correct_choices= dict() #place holder for all question's correct choices

    for question in kg_questions:
      #get answers to the question
      choices = [o for p,o in by_subject.get(question, ()) if 'hasChoices' in p]
      #get correct choices; NOTE(review): substring test, so a value such as
      #'incorrect' would also match -- confirm which values the KG uses
      question_correct_choices = [choice for choice in choices
                                  for p,o in by_subject.get(choice, ())
                                  if 'hasStringValue' in p and 'correct' in o]
      #get correct choice programs
      correct_choices[question] = [str(o) for choice in question_correct_choices
                                   for p,o in by_subject.get(choice, ())
                                   if 'hasProgram' in p]

    gc.collect()
    return correct_choices #return all question's correct choices

  @staticmethod
  def get_knowledge(question,kg): #get knowledge relevant to the question from kg
    """Return the triples of ``kg`` that describe ``question``'s video.

    The video id is the part of ``question`` before ``'__'``, after the last
    ``'#'``. A triple is kept when its subject contains the video id and one of
    ``'Scene'``, ``'Observation'``, ``'Point'``. It is dropped when it is a
    string value, a program, a choice, a question other than ``question``, or
    a ``'type'`` triple whose object contains ``'owl'`` or
    ``'trafficmonitoring'``.

    Returns:
      list of ``(s, p, o)`` tuples in ``kg`` iteration order ([] for empty kg).
    """
    relevant_knowledge = []
    video_id = question.split('__')[0].split('#')[-1] #get video id

    for s,p,o in kg: #get all the knowledge from the video
      if 'hasStringValue' in p or 'hasProgram' in p:
        continue #we dont want the string value or program value
      if 'hasQuestion' in p and o!=question:
        continue #dont care about other questions
      if 'hasChoices' in p: #dont care about question choices
        continue
      if 'type' in p and 'owl' in o: #omit owl types
        continue
      if 'type' in p and 'trafficmonitoring' in o: #omit parent ontology types
        continue
      if video_id in s and ('Scene' in s or 'Observation' in s or 'Point' in s): #subject has something to do with the video
        relevant_knowledge.append((s,p,o))

    #no gc.collect() here: this runs on every forward pass, and a full
    #collection over a large rdflib graph is expensive
    return relevant_knowledge

  @staticmethod
  def get_question_matrix(question,
                          kg,
                          kg_perc = 0.01): #get transitive closure matrix corresponding to the question
    """Return ``(adj_M, n_nodes)`` for a random sample of ``question``'s knowledge.

    Each relevant triple is kept independently with probability ``kg_perc``,
    using one ``random()`` draw per triple. A kept triple gets four fresh node
    ids (s, p, o, p_inv) wired as the cycle s -> p -> o -> p_inv -> s. ``adj_M``
    has 1.0 at every edge of the NetworkX transitive closure
    (``reflexive=False``) of those edges. Per the NetworkX docs that setting
    adds self-loops for nodes on cycles, so each triple presumably becomes a
    dense 4x4 block -- TODO confirm.

    Returns:
      adj_M: dense ``(n_nodes, n_nodes)`` float NumPy array (``(0, 0)`` when no
        triple is sampled).
      n_nodes: ``4 * number of sampled triples``.
    """
    rkg = CLEVRER_KG_Utils.get_knowledge(question,kg) #get the relevant question knowledge

    adj_matrix = {} #edge (u, v) -> 1; the keys act as an ordered edge list
    er_counter = 0 #keep count of total no. of nodes and relationships

    #only the number of sampled triples matters: the node labels the old
    #adj_index stored were never used
    for _ in rkg:
      if random() > kg_perc:
        continue
      #forward connections: s -> p -> o
      adj_matrix[(er_counter,er_counter+1)] = 1
      adj_matrix[(er_counter+1,er_counter+2)] = 1
      #inverse connections: o -> p_inv -> s
      adj_matrix[(er_counter+3,er_counter)] = 1
      adj_matrix[(er_counter+2,er_counter+3)] = 1
      er_counter += 4

    DG = nx.DiGraph(list(adj_matrix.keys())) #convert to networkx digraph to compute transitive closure
    edges = list(nx.transitive_closure(DG, reflexive=False).edges()) #inferred relationships

    adj_M = np.zeros((er_counter,er_counter))
    if edges: #vectorized fill of the closure edges
      rows, cols = zip(*edges)
      adj_M[list(rows), list(cols)] = 1.0

    #the old `del triple; del i` raised UnboundLocalError whenever no triple
    #was sampled (likely with kg_perc=0.01), so they are gone; gc.collect()
    #is also skipped because this runs on every forward pass
    return adj_M, er_counter #return the adjacency matrix, and no. of nodes and relationships

#define tokenizer
class Tokenizer(object):
  """Character-level tokenizer over questions and their formatted choice programs.

  The vocabulary is every character that appears in a question string or in
  its formatted choices string ``'CH:' + '^_^<E>^_^'.join(programs) + '^_^<E>'``.
  Token ids follow sorted character order, so they are the same in every run.
  """

  def __init__(self,
               kg_questions,
               kg_choices):
    """Build the vocabulary.

    Args:
      kg_questions: ``{question: question_string}``.
      kg_choices: ``{question: [program_string, ...]}``; must have an entry
        for every key of ``kg_questions``.
    """
    chars = set() #every character seen so far

    print ('Initializing tokenizer ... ')
    for question in tqdm(kg_questions):
      chars.update(kg_questions[question]) #add question characters
      #format answer to separate question from choice 'CH', and primitive within the choice '<E>'
      formatted_choices = 'CH:'+('^_^<E>^_^'.join(kg_choices[question]))+'^_^<E>'
      chars.update(formatted_choices) #add choice program characters

    #sorted: iterating a set of str depends on the per-process hash seed
    #(PYTHONHASHSEED), which made token ids differ between runs
    self.chars = sorted(chars) #id -> char
    self.vocab_size = len(self.chars)
    self.char_index = {char: i for i, char in enumerate(self.chars)} #char -> id
    #(no trailing `del` of loop variables: it raised NameError for empty input)
    gc.collect()

  def encode(self,
             string):
    """Return the list of token ids for ``string``; raises KeyError on unseen characters."""
    return [self.char_index[char] for char in string]

  def decode(self,
             encoding):
    """Return the string whose characters have the token ids in ``encoding``."""
    return ''.join([self.chars[i] for i in encoding])

#define dataloader class
class Dataloader(object):
  """Turns questions and their correct-choice programs into next-token examples.

  ``get_batch`` stores its result on ``self.data`` and returns None.
  """

  def __init__(self,
               tokenizer):
    """Store ``tokenizer``; only its ``encode(string) -> list[int]`` method is used."""
    self.tokenizer = tokenizer #store tokenizer

  def get_batch(self,
                kg_questions,
                kg_answers,
                n = None):
    """Build next-token training examples and store them on ``self.data``.

    Questions are processed in shuffled order. For each one, the token ids of
    the question are followed by those of
    ``'CH:' + '^_^<E>^_^'.join(answers) + '^_^<E>'``. Every proper prefix of
    that sequence becomes an example ``[(question, prefix_ids), next_id]``.

    Args:
      kg_questions: ``{question: question_string}``.
      kg_answers: ``{question: [program_string, ...]}``, with an entry for every
        key of ``kg_questions``.
      n: if given, keep a random sample of ``n`` examples; ``random.sample``
        raises ValueError when fewer than ``n`` exist.
    """
    data = [] #[(question, prefix_ids), next_id] items
    question_batch = list(kg_questions.keys()); shuffle(question_batch) #use all questions (shuffled)

    for question in question_batch:
      question_encoding = self.tokenizer.encode(kg_questions[question])
      #format answer to separate question from choice 'CH', and primitive within the choice '<E>'
      formatted_choices = 'CH:'+('^_^<E>^_^'.join(kg_answers[question]))+'^_^<E>'
      #concatenate the two because we want to generate both
      full_encoding = question_encoding + self.tokenizer.encode(formatted_choices)
      for t in range(1, len(full_encoding)):
        data.append([(question, full_encoding[:t]), full_encoding[t]])

    #no trailing `del` of loop variables: it raised NameError for empty input
    self.data = data if n is None else sample(data, n)
    gc.collect()

#define class for general tensor utilities
class Tensor_Utils(object):
  """Namespace for small, stateless tensor helpers."""

  @staticmethod
  def normed(T):
    """Return ``T`` divided by its 2-norm taken over all elements.

    The norm is turned into a Python float with ``.item()``, so autograd
    treats it as a constant and no gradient flows through the normalizer.
    A zero tensor gives non-finite values (division by zero).
    """
    scale = torch.linalg.norm(T).item() #flattened 2-norm as a Python float
    return T.div(scale)

#define generator class that will input question,
#and generate the correct choice programs
class Generator(nn.Module):
  """Next-character model that reads a question and generates choice programs.

  The model is one multi-headed self-attention block over character plus
  position embeddings, followed by a linear classification head over the
  vocabulary. A knowledge vector is added to every position; it comes from
  the question's adjacency matrix (``CLEVRER_KG_Utils.get_question_matrix``).
  """

  def __init__(self,
               vocab_size = None,
               emb_size = None,
               context_size = None,
               n_heads = None,
               kg = None,
               max_nodes = 100000): #it uses a single multiheaded self-attention block
    """Create the layers.

    Args:
      vocab_size: number of character tokens (output classes).
      emb_size: embedding width; nn.MultiheadAttention needs it divisible by ``n_heads``.
      context_size: maximum number of tokens read per forward pass.
      n_heads: number of attention heads.
      kg: knowledge graph passed to ``CLEVRER_KG_Utils.get_question_matrix``.
      max_nodes: rows of the knowledge-node embedding table; a question matrix
        with more nodes than this cannot be embedded.
    """
    super().__init__() #call superclass constructor

    #store config data
    self.vocab_size = vocab_size
    self.emb_size = emb_size
    self.context_size = context_size
    self.n_heads = n_heads
    self.kg = kg
    self.max_nodes = max_nodes

    #embedding layer for knowledge entities and relations, both represented as nodes in a graph (adjacency matrix)
    self.knowledge_embedding_matrix = nn.Embedding(max_nodes,emb_size)

    #embedding layer with position encodings
    self.embeddings = nn.Embedding(self.vocab_size, self.emb_size)
    self.pos_embeddings = nn.Embedding(self.context_size, self.emb_size)

    #query, key, value matrix and self-attention operation
    self.query = nn.Linear(self.emb_size,self.emb_size,bias=False)
    self.key = nn.Linear(self.emb_size,self.emb_size,bias=False)
    self.value = nn.Linear(self.emb_size,self.emb_size,bias=False)
    self.multihead_attn = nn.MultiheadAttention(self.emb_size,self.n_heads)

    #classification head
    self.head = nn.Linear(self.emb_size,self.vocab_size)
    self.attn_output_weights = None #place holder for attention weights

  def forward(self,
              question_data,
              kg_perc = 0.01):
    """Return next-token logits for one datapoint.

    Args:
      question_data: a datapoint as built by the dataloader,
        ``[(question, token_ids), target]``. Only ``question_data[0]`` is read,
        and only the last ``context_size`` token ids are used.
      kg_perc: per-triple sampling probability passed to ``get_question_matrix``.

    Returns:
      The head's leaky-ReLU'd output at the last token position (presumably a
      1-D tensor of size ``vocab_size`` -- confirm for your torch version).
    """
    kg = self.kg #shorthand
    question_encoding = question_data[0][1][-self.context_size:] #get tokenized question
    question = question_data[0][0] #get question
    adj_matrix, n_nodes = CLEVRER_KG_Utils.get_question_matrix(question,kg,kg_perc=kg_perc) #get adjacency matrix and no. of graph nodes

    question_tensor = torch.tensor(question_encoding) #convert to tensor
    n_tokens = len(question_tensor)
    question_embedding = self.embeddings(question_tensor) #get question embedding with position encodings
    question_embedding += self.pos_embeddings(torch.arange(n_tokens))

    #when sampling kept no triple the matrix is (0, 0); the mean over zero
    #rows would be NaN and poison every position, so skip the knowledge term
    if n_nodes > 0:
      adj_matrix = torch.from_numpy(adj_matrix).float() #(n_nodes, n_nodes)
      node_vectors = self.knowledge_embedding_matrix(torch.arange(n_nodes)) #(n_nodes, emb_size)
      question_embedding += (adj_matrix @ node_vectors).mean(dim=0) #adjacency-weighted node vectors, averaged

    Q = self.query(question_embedding) #multi-headed self-attention computation
    K = self.key(question_embedding)
    V = self.value(question_embedding)
    question_embedding, self.attn_output_weights = self.multihead_attn(Q,K,V)

    logits = F.leaky_relu(self.head(question_embedding))[-1] #logits at the last token position
    return logits

  def train(self,
            kg_questions = True,
            kg_choices = None,
            dataloader_object = None,
            batch_size = 32,
            epochs = 100,
            kg_perc = 0.01): #training function for the generator
    """Train the generator, or set train/eval mode like ``nn.Module.train``.

    This method shadows ``nn.Module.train(mode=True)``, and ``nn.Module.eval()``
    calls ``self.train(False)``. A bool first argument is therefore treated as
    that ``mode`` and forwarded to ``nn.Module.train``, which returns ``self``.
    This keeps ``eval()`` and ``train(mode)`` working.

    Otherwise this runs ``epochs`` optimisation steps. Each step calls
    ``dataloader_object.get_batch(kg_questions, kg_choices, n=batch_size)``,
    averages the cross-entropy over ``dataloader_object.data``, clips the
    gradient norm to 1.0, and takes one AdamW step.

    Args:
      kg_questions: ``{question: question_string}`` (or a bool mode, see above).
      kg_choices: ``{question: [program_string, ...]}``.
      dataloader_object: object whose ``get_batch`` stores
        ``[(question, token_ids), target_id]`` items on ``.data``.
      batch_size: datapoints per step.
      epochs: number of optimisation steps.
      kg_perc: knowledge sampling probability forwarded to ``forward``.
    """
    if isinstance(kg_questions, bool): #called as nn.Module.train(mode), e.g. from eval()
      return super().train(kg_questions)

    super().train(True) #put sub-modules in training mode (a prior eval() may have switched it off)
    optimizer = torch.optim.AdamW(self.parameters()) #initialize optimizer
    loss = F.cross_entropy #set loss to cross entropy loss (one-hot probability targets)

    print ('Starting training loop ... ')
    for _ in tqdm(range(epochs)): #training loop
      #use the dataloader passed in (this used to read a global named `dl`)
      dataloader_object.get_batch(kg_questions, kg_choices, n = batch_size) #get a batch of data
      batch = dataloader_object.data
      n_batch = len(batch) #calculate no. of data points
      if n_batch == 0: #nothing to learn from; also avoids dividing by zero below
        continue

      batch_loss = 0.0 #initialize batch_loss
      for data_point in batch:
        x, y = data_point, data_point[-1] #forward() reads x[0]; y is the target token id
        logits = self(x,kg_perc=kg_perc) #compute forward pass
        #compute one hot encoding for targets
        targets = [0.0]*self.vocab_size; targets[y] = 1.0
        targets = torch.tensor(targets) #convert to tensor
        batch_loss += loss(logits,targets) #add to total batch loss

      batch_loss /= n_batch #compute average batch loss
      print ('batch loss: ',batch_loss.item()) #print batch loss to check convergence
      #perform optimization step
      batch_loss.backward()
      nn.utils.clip_grad_norm_(self.parameters(), 1.0)
      optimizer.step()
      optimizer.zero_grad()
    gc.collect()

"""#Unit tests below"""

#Each TEST_* flag below gates one scratch check and is False by default.
#Every enabled check re-parses the KG file from Drive.
#NOTE: `del kg` in these cells also removes the module-level `kg` bound earlier.
TEST_KNOWLEDGE_GRAPH_PARSER = False
if TEST_KNOWLEDGE_GRAPH_PARSER:
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  triples = [(s,p,o) for (s,p,o) in kg]
  from random import choice
  print (choice(triples)) #print one random triple
  del kg; del triples; gc.collect()

TEST_GET_KG_QUESTIONS = False
if TEST_GET_KG_QUESTIONS:
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  from random import choice
  print ('\n'+choice(list(kg_questions.values()))) #print one random question string
  del kg; del kg_questions; gc.collect()

TEST_GET_KG_QUESTION_CHOICES = False
if TEST_GET_KG_QUESTION_CHOICES:
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  kg_choices = CLEVRER_KG_Utils.get_correct_choices(kg,kg_questions)
  from random import choice
  random_question = choice(list(kg_questions.keys()))
  print (kg_choices[random_question]) #correct-choice programs of one random question
  del kg; del kg_questions; del kg_choices; del random_question; gc.collect()

TEST_GET_QUESTION_KNOWLEGE = False
if TEST_GET_QUESTION_KNOWLEGE:
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  from random import choice
  random_question = choice(list(kg_questions.keys()))
  print (CLEVRER_KG_Utils.get_knowledge(random_question,kg)) #relevant triples of one random question
  del kg; del kg_questions; del random_question; gc.collect()

#check CLEVRER_KG_Utils.get_question_matrix(question, kg) on a random question
TEST_GET_QUESTION_MATRIX = False
if TEST_GET_QUESTION_MATRIX:
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  from random import choice
  random_question = choice(list(kg_questions.keys()))
  print (CLEVRER_KG_Utils.get_question_matrix(random_question,kg)) #(adjacency matrix, no. of nodes)
  del kg; del kg_questions; del random_question; gc.collect()

#round-trip a sample sentence through the tokenizer (encode raises KeyError
#for characters that never appear in the KG questions/choices)
TEST_TOKENIZER = False
if TEST_TOKENIZER: 
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  kg_choices = CLEVRER_KG_Utils.get_correct_choices(kg,kg_questions)
  t = Tokenizer(kg_questions,kg_choices)
  print ('\n'+t.decode(t.encode("What will happen without the green cube?")))
  del kg; del kg_questions; del kg_choices; del t; gc.collect()

TEST_DATALOADER = False
if TEST_DATALOADER: 
  from random import choice
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  kg_choices = CLEVRER_KG_Utils.get_correct_choices(kg,kg_questions)
  t = Tokenizer(kg_questions,kg_choices)
  dl = Dataloader(t)
  dl.get_batch(kg_questions,kg_choices)
  #print the kth datapoint; NOTE(review): k ranges over the number of questions,
  #not of datapoints, so only the first len(kg_questions) datapoints can be picked
  k = choice(range(len(kg_questions)))
  print (dl.data[k][0][1]) #input token ids (prefix)
  print (dl.data[k][-1]) #target token id
  del kg; del kg_questions; del kg_choices; del t; del dl; del k; gc.collect()

TEST_GENERATOR_FORWARD_PASS = False
if TEST_GENERATOR_FORWARD_PASS:
  from random import choice
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  kg_choices = CLEVRER_KG_Utils.get_correct_choices(kg,kg_questions)
  t = Tokenizer(kg_questions,kg_choices)
  dl = Dataloader(t)
  dl.get_batch(kg_questions,kg_choices)
  k = choice(range(len(kg_questions)))   #choose a random k (index into dl.data)
  datapoint = dl.data[k]
  #emb_size must be divisible by n_heads for nn.MultiheadAttention (96 / 12 = 8)
  g = Generator(vocab_size=t.vocab_size,
                emb_size = 96,
                context_size = 100,
                n_heads = 12,
                kg = kg)
  print (g(datapoint)) #logits for the next token
  del kg; del kg_questions; del kg_choices; del t; del dl; del k; del g; gc.collect()

TEST_GENERATOR_TRAINING = False
if TEST_GENERATOR_TRAINING:
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  kg_choices = CLEVRER_KG_Utils.get_correct_choices(kg,kg_questions)
  t = Tokenizer(kg_questions,kg_choices)
  dl = Dataloader(t)
  g = Generator(vocab_size=t.vocab_size,
                emb_size = 96,
                context_size = 100,
                n_heads = 12,
                kg = kg)
  #100 optimisation steps of 32 sampled datapoints each, keeping ~10% of the
  #relevant knowledge triples per forward pass
  g.train(kg_questions,kg_choices,dl,batch_size = 32,epochs=100,kg_perc=0.1)
  del kg; del kg_questions; del kg_choices; del t; del dl; del g; gc.collect()
